[ROCDL] Added matrix load-transpose ops for gfx1250+ #165564
Force-pushed: f678396 to 895975f
krzysz00
left a comment
ROCDL is a trivial wrapper around LLVM intrinsics and should match their definitions, polymorphism, etc. as closely as possible.
def ROCDL_GlobalLoadTr6_3I32 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<GlobalAddrKind, 6, 96, IType<I32>>>;
def ROCDL_GlobalLoadTr8_8I16 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<GlobalAddrKind, 16, 128, IType<I16>>>;
//def ROCDL_GlobalLoadTr8_8F16 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<GlobalAddrKind, 8, 128, FType<F16>>>;
//def ROCDL_GlobalLoadTr8_8BF16 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<GlobalAddrKind, 8, 128, BF16Type>>;
Strong reject of having separate "f16", "bf16", and similar variants.
Just make it variadic.
Sure. Removed constraints from the output type.
Force-pushed: 0b493c5 to c63c02f
@llvm/pr-subscribers-mlir-llvm @llvm/pr-subscribers-mlir
Author: Ravil Dorozhinskii (ravil-mobile)
Changes
This patch adds load-transpose instructions for the gfx1250+ architecture to ROCDL. Note that this is a work in progress, but I'd like to share the ideas here and hope to get some comments.
Full diff: https://github.com/llvm/llvm-project/pull/165564.diff
3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 5241f9a6f2b43..a6666d7379404 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -321,6 +321,7 @@ def ROCDL_BarrierOp : ROCDL_Op<"barrier"> {
let assemblyFormat = "attr-dict";
}
+def ROCDLGlobalBuffer : LLVM_PointerInAddressSpace<1>;
def ROCDLBufferLDS : LLVM_PointerInAddressSpace<3>;
def ROCDL_BarrierInitOp : ROCDL_IntrOp<"s.barrier.init", [], [], [], 0, 0, 0, 0, [1], ["id"]>,
@@ -631,8 +632,6 @@ def ROCDL_wmma_i32_16x16x64_iu8 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x64.iu8", [1]
//===---------------------------------------------------------------------===//
// LDS transpose intrinsics (available in GFX950)
-def ROCDLGlobalBuffer : LLVM_PointerInAddressSpace<1>;
-
class ROCDL_LDS_Read_Tr_IntrOp<string mnemonic> :
ROCDL_IntrOp<mnemonic, [1], [], [], 1, 0, 1> {
dag args = (ins Arg<ROCDLBufferLDS, "", [MemRead]>:$ptr);
@@ -650,6 +649,58 @@ def ROCDL_ds_read_tr8_b64 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr8.b64">;
def ROCDL_ds_read_tr6_b96 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr6.b96">;
def ROCDL_ds_read_tr16_b64 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr16.b64">;
+
+
+//===---------------------------------------------------------------------===//
+// Glb/DS load-transpose intrinsics (available in GFX1250+)
+
+class AddrKind<string n, int s> {
+ string name = n;
+ int space = s;
+}
+def GlobalAddrKind : AddrKind<"global", 1>;
+def DSAddrKind : AddrKind<"ds", 3>;
+
+class ROCDL_TrLoadOpMeta<AddrKind kind, int inElemBits, int outElemBits> {
+ AddrKind addrKind = kind;
+ string inBits = !cast<string>(inElemBits);
+ string outBits = !cast<string>(outElemBits);
+ string inBitsEnc = !if(!eq(addrKind.space, 1),
+ !if(!or(!eq(inElemBits, 8), !eq(inElemBits, 16)), "", inBits), inBits);
+ string mnemonic = addrKind.name # ".load.tr" # inBitsEnc # ".b" # outBits;
+}
+
+class ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta meta> :
+ ROCDL_IntrOp<meta.mnemonic, [1], [], [], 1, 0, 1> {
+
+ dag args = (ins Arg<LLVM_PointerInAddressSpace<meta.addrKind.space>, "", [MemRead]>:$ptr);
+ let arguments = !con(args, baseArgs);
+ let summary = "Loads and transposes a matrix from " # meta.addrKind.name # " memory or ds to registers (available in gfx1250+).";
+ let description = [{
+ Load a matrix of }] # meta.inBits # [{-bit data from the }] # meta.addrKind.name # [{ memory,
+ transpose data between row-major and column-major order,
+ and store the result into a }] # meta.outBits # [{-bit vector register.
+
+ Available in gfx1250+.
+ }];
+ let assemblyFormat = "$ptr attr-dict `:` type($ptr) `->` type($res)";
+ let extraClassDefinition = [{
+ ::llvm::SmallVector<::mlir::Value> $cppClass::getAccessedOperands() {
+ return {getPtr()};
+ }
+ }];
+}
+
+def ROCDL_GlobalLoadTr4_B64 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<GlobalAddrKind, 4, 64>>;
+def ROCDL_GlobalLoadTr8_B64 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<GlobalAddrKind, 8, 64>>;
+def ROCDL_GlobalLoadTr6_B96 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<GlobalAddrKind, 6, 96>>;
+def ROCDL_GlobalLoadTr8_B128 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<GlobalAddrKind, 16, 128>>;
+
+def ROCDL_DsLoadTr4_B64 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<DSAddrKind, 4, 64>>;
+def ROCDL_DsLoadTr8_B64 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<DSAddrKind, 8, 64>>;
+def ROCDL_DsLoadTr6_B96 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<DSAddrKind, 6, 96>>;
+def ROCDL_DsLoadTr16_B128 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<DSAddrKind, 16, 128>>;
+
//===---------------------------------------------------------------------===//
// Load to LDS intrinsic (available in GFX9 and GFX10)
//===---------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index e703600c71c8e..8bcf2bb06a4b8 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -650,6 +650,39 @@ llvm.func @rocdl.ds.read.tr(%ptr : !llvm.ptr<3>) -> vector<4xf16> {
llvm.return %r3 : vector<4xf16>
}
+llvm.func @rocdl.load.tr.ops(%gl_ptr : !llvm.ptr<1>, %ds_ptr : !llvm.ptr<3>) {
+ // CHECK-LABEL: @rocdl.load.tr.ops
+ // CHECK-SAME: (%[[GL_PTR:.+]]: !llvm.ptr<1>, %[[DS_OTR:.+]]: !llvm.ptr<3>)
+ // CHECK: rocdl.global.load.tr4.b64 %[[GL_PTR]] : <1> -> vector<2xi32>
+ // CHECK: rocdl.global.load.tr.b64 %[[GL_PTR]] : <1> -> vector<2xi32>
+ // CHECK: rocdl.global.load.tr6.b96 %[[GL_PTR]] : <1> -> vector<3xi32>
+ // CHECK: rocdl.global.load.tr.b128 %[[GL_PTR]] : <1> -> vector<8xi16>
+ // CHECK: rocdl.global.load.tr.b128 %[[GL_PTR]] : <1> -> vector<8xf16>
+ // CHECK: rocdl.global.load.tr.b128 %[[GL_PTR]] : <1> -> vector<8xbf16>
+ // CHECK: rocdl.ds.load.tr4.b64 %[[DS_OTR]] : <3> -> vector<2xi32>
+ // CHECK: rocdl.ds.load.tr8.b64 %[[DS_OTR]] : <3> -> vector<2xi32>
+ // CHECK: rocdl.ds.load.tr6.b96 %[[DS_OTR]] : <3> -> vector<3xi32>
+ // CHECK: rocdl.ds.load.tr16.b128 %[[DS_OTR]] : <3> -> vector<8xi16>
+ // CHECK: rocdl.ds.load.tr16.b128 %[[DS_OTR]] : <3> -> vector<8xf16>
+ // CHECK: rocdl.ds.load.tr16.b128 %[[DS_OTR]] : <3> -> vector<8xbf16>
+ // CHECK: llvm.return
+
+ rocdl.global.load.tr4.b64 %gl_ptr : !llvm.ptr<1> -> vector<2 x i32>
+ rocdl.global.load.tr.b64 %gl_ptr : !llvm.ptr<1> -> vector<2 x i32>
+ rocdl.global.load.tr6.b96 %gl_ptr : !llvm.ptr<1> -> vector<3 x i32>
+ rocdl.global.load.tr.b128 %gl_ptr : !llvm.ptr<1> -> vector<8 x i16>
+ rocdl.global.load.tr.b128 %gl_ptr : !llvm.ptr<1> -> vector<8 x f16>
+ rocdl.global.load.tr.b128 %gl_ptr : !llvm.ptr<1> -> vector<8 x bf16>
+
+ rocdl.ds.load.tr4.b64 %ds_ptr : !llvm.ptr<3> -> vector<2 x i32>
+ rocdl.ds.load.tr8.b64 %ds_ptr : !llvm.ptr<3> -> vector<2 x i32>
+ rocdl.ds.load.tr6.b96 %ds_ptr : !llvm.ptr<3> -> vector<3 x i32>
+ rocdl.ds.load.tr16.b128 %ds_ptr : !llvm.ptr<3> -> vector<8 x i16>
+ rocdl.ds.load.tr16.b128 %ds_ptr : !llvm.ptr<3> -> vector<8 x f16>
+ rocdl.ds.load.tr16.b128 %ds_ptr : !llvm.ptr<3> -> vector<8 x bf16>
+ llvm.return
+}
+
llvm.func @rocdl.load.to.lds(%src : !llvm.ptr<7>, %dst: !llvm.ptr<3>) {
// CHECK-LABEL @rocdl.load.to.lds
//CHECK: rocdl.load.to.lds %{{.*}}, %{{.*}}, 4, 0, 0 : <7>
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 8a848221a50dd..0a556f5b5a845 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -1028,6 +1028,39 @@ llvm.func @rocdl.ds.read.tr(%ptr : !llvm.ptr<3>) -> vector<4xf16> {
llvm.return %r3 : vector<4xf16>
}
+llvm.func @rocdl.load.tr.ops(%gl_ptr : !llvm.ptr<1>, %ds_ptr : !llvm.ptr<3>) {
+ // CHECK-LABEL: rocdl.load.tr.ops
+ // CHECK-SAME: (ptr addrspace(1) %[[GL_PTR:.+]], ptr addrspace(3) %[[DS_PTR:.+]])
+ // CHECK: call <2 x i32> @llvm.amdgcn.global.load.tr4.b64.v2i32(ptr addrspace(1) %[[GL_PTR]])
+ // CHECK: call <2 x i32> @llvm.amdgcn.global.load.tr.b64.v2i32(ptr addrspace(1) %[[GL_PTR]])
+ // CHECK: call <3 x i32> @llvm.amdgcn.global.load.tr6.b96.v3i32(ptr addrspace(1) %[[GL_PTR]])
+ // CHECK: call <8 x i16> @llvm.amdgcn.global.load.tr.b128.v8i16(ptr addrspace(1) %[[GL_PTR]])
+ // CHECK: call <8 x half> @llvm.amdgcn.global.load.tr.b128.v8f16(ptr addrspace(1) %[[GL_PTR]])
+ // CHECK: call <8 x bfloat> @llvm.amdgcn.global.load.tr.b128.v8bf16(ptr addrspace(1) %[[GL_PTR]])
+
+ // CHECK: call <2 x i32> @llvm.amdgcn.ds.load.tr4.b64.v2i32(ptr addrspace(3) %[[DS_PTR]])
+ // CHECK: call <2 x i32> @llvm.amdgcn.ds.load.tr8.b64.v2i32(ptr addrspace(3) %[[DS_PTR]])
+ // CHECK: call <3 x i32> @llvm.amdgcn.ds.load.tr6.b96.v3i32(ptr addrspace(3) %[[DS_PTR]])
+ // CHECK: call <8 x i16> @llvm.amdgcn.ds.load.tr16.b128.v8i16(ptr addrspace(3) %[[DS_PTR]])
+ // CHECK: call <8 x half> @llvm.amdgcn.ds.load.tr16.b128.v8f16(ptr addrspace(3) %[[DS_PTR]])
+ // CHECK: call <8 x bfloat> @llvm.amdgcn.ds.load.tr16.b128.v8bf16(ptr addrspace(3) %[[DS_PTR]])
+
+ rocdl.global.load.tr4.b64 %gl_ptr : !llvm.ptr<1> -> vector<2 x i32>
+ rocdl.global.load.tr.b64 %gl_ptr : !llvm.ptr<1> -> vector<2 x i32>
+ rocdl.global.load.tr6.b96 %gl_ptr : !llvm.ptr<1> -> vector<3 x i32>
+ rocdl.global.load.tr.b128 %gl_ptr : !llvm.ptr<1> -> vector<8 x i16>
+ rocdl.global.load.tr.b128 %gl_ptr : !llvm.ptr<1> -> vector<8 x f16>
+ rocdl.global.load.tr.b128 %gl_ptr : !llvm.ptr<1> -> vector<8 x bf16>
+
+ rocdl.ds.load.tr4.b64 %ds_ptr : !llvm.ptr<3> -> vector<2 x i32>
+ rocdl.ds.load.tr8.b64 %ds_ptr : !llvm.ptr<3> -> vector<2 x i32>
+ rocdl.ds.load.tr6.b96 %ds_ptr : !llvm.ptr<3> -> vector<3 x i32>
+ rocdl.ds.load.tr16.b128 %ds_ptr : !llvm.ptr<3> -> vector<8 x i16>
+ rocdl.ds.load.tr16.b128 %ds_ptr : !llvm.ptr<3> -> vector<8 x f16>
+ rocdl.ds.load.tr16.b128 %ds_ptr : !llvm.ptr<3> -> vector<8 x bf16>
+ llvm.return
+}
+
llvm.func @rocdl.load.to.lds(%src : !llvm.ptr<7>, %dst: !llvm.ptr<3>) {
//CHECK: call void @llvm.amdgcn.load.to.lds.p7
rocdl.load.to.lds %src, %dst, 4, 0, 0 : !llvm.ptr<7>
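As a reading aid for the diff above, here is a standalone TableGen sketch of how `ROCDL_TrLoadOpMeta` assembles each mnemonic. The class name `TrLoadMnemonic` and the `Example*` record names are made up for illustration and are not part of the patch; the string logic is copied from it. The records should be inspectable with `llvm-tblgen -print-records`, and the mnemonics in the trailing comments are the ones the tests in this PR check for.

```tablegen
// Illustrative only: mirrors the mnemonic logic of ROCDL_TrLoadOpMeta.
class AddrKind<string n, int s> {
  string name = n;
  int space = s;
}
def GlobalAddrKind : AddrKind<"global", 1>;
def DSAddrKind : AddrKind<"ds", 3>;

// Hypothetical helper class (not in the patch).
class TrLoadMnemonic<AddrKind kind, int inElemBits, int outElemBits> {
  string inBits = !cast<string>(inElemBits);
  // Global (address space 1) loads of 8- and 16-bit elements omit the
  // element width from the name; every other combination keeps it.
  string inBitsEnc = !if(!eq(kind.space, 1),
      !if(!or(!eq(inElemBits, 8), !eq(inElemBits, 16)), "", inBits), inBits);
  string mnemonic = kind.name # ".load.tr" # inBitsEnc # ".b"
                    # !cast<string>(outElemBits);
}

def ExampleGlobal4  : TrLoadMnemonic<GlobalAddrKind, 4, 64>;   // global.load.tr4.b64
def ExampleGlobal8  : TrLoadMnemonic<GlobalAddrKind, 8, 64>;   // global.load.tr.b64
def ExampleGlobal16 : TrLoadMnemonic<GlobalAddrKind, 16, 128>; // global.load.tr.b128
def ExampleDs8      : TrLoadMnemonic<DSAddrKind, 8, 64>;       // ds.load.tr8.b64
def ExampleDs16     : TrLoadMnemonic<DSAddrKind, 16, 128>;     // ds.load.tr16.b128
```

The asymmetry between the `global` and `ds` names comes from the underlying LLVM intrinsic names that the tests expect, which is why the element width is dropped only in the global case.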
mlir/test/Dialect/LLVMIR/rocdl.mlir
Outdated
rocdl.global.load.tr4.b64 %gl_ptr : !llvm.ptr<1> -> vector<2 x i32>
rocdl.global.load.tr.b64 %gl_ptr : !llvm.ptr<1> -> vector<2 x i32>
rocdl.global.load.tr6.b96 %gl_ptr : !llvm.ptr<1> -> vector<3 x i32>
rocdl.global.load.tr.b128 %gl_ptr : !llvm.ptr<1> -> vector<8 x i16>
rocdl.global.load.tr.b128 %gl_ptr : !llvm.ptr<1> -> vector<8 x f16>
rocdl.global.load.tr.b128 %gl_ptr : !llvm.ptr<1> -> vector<8 x bf16>

rocdl.ds.load.tr4.b64 %ds_ptr : !llvm.ptr<3> -> vector<2 x i32>
rocdl.ds.load.tr8.b64 %ds_ptr : !llvm.ptr<3> -> vector<2 x i32>
rocdl.ds.load.tr6.b96 %ds_ptr : !llvm.ptr<3> -> vector<3 x i32>
rocdl.ds.load.tr16.b128 %ds_ptr : !llvm.ptr<3> -> vector<8 x i16>
rocdl.ds.load.tr16.b128 %ds_ptr : !llvm.ptr<3> -> vector<8 x f16>
rocdl.ds.load.tr16.b128 %ds_ptr : !llvm.ptr<3> -> vector<8 x bf16>
Drop spaces in vector dims. Also in the other file.
Sure, done
Force-pushed: c63c02f to 24aa2e7
krzysz00
left a comment
Overall this seems fine to me, but I'd wait for @kuhar to take another look.
let description = [{
Load a matrix of }] # meta.inBits # [{-bit data from the }] # meta.addrKind.name # [{ memory,
transpose data between row-major and column-major order,
and store the result into a }] # meta.outBits # [{-bit vector register.
I'll note these instructions seem generally underdocumented. While this PR may not be the right place for them, can we make a plan for surfacing their exact semantics?
@krzysz00, you mean in the future, right?
class AddrKind<string n, int s> {
string name = n;
int space = s;
}
def GlobalAddrKind : AddrKind<"global", 1>;
def DSAddrKind : AddrKind<"ds", 3>;
Nice!
mlir/test/Dialect/LLVMIR/rocdl.mlir
Outdated
llvm.func @rocdl.load.tr.ops(%gl_ptr : !llvm.ptr<1>, %ds_ptr : !llvm.ptr<3>) {
// CHECK-LABEL: @rocdl.load.tr.ops
// CHECK-SAME: (%[[GL_PTR:.+]]: !llvm.ptr<1>, %[[DS_OTR:.+]]: !llvm.ptr<3>)
// CHECK: rocdl.global.load.tr4.b64 %[[GL_PTR]] : <1> -> vector<2xi32>
Why do these get printed without the !llvm.ptr?
That's fairly typical in TableGen: if the !foo or #foo is statically known, it gets omitted unless you explicitly wrap the printed term in qualified(...).
OK. I think it would be nicer if it was qualified.
Done
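To make the `qualified(...)` remark in this thread concrete, the sketch below contrasts the two spellings of the assembly format. The first `let` is the format from the patch as originally posted; the second is an assumption about what the qualified form would look like, using the `qualified` directive of MLIR's declarative assembly format. It is a fragment of an op definition rather than a complete record, and only one of the two lines would appear in a real definition.

```tablegen
// As originally posted: the pointer type prints as `<1>`, because the
// `!llvm.ptr` prefix is statically known from the operand constraint.
//   rocdl.global.load.tr4.b64 %ptr : <1> -> vector<2xi32>
let assemblyFormat = "$ptr attr-dict `:` type($ptr) `->` type($res)";

// Assumed qualified variant: the pointer type prints in full.
//   rocdl.global.load.tr4.b64 %ptr : !llvm.ptr<1> -> vector<2xi32>
let assemblyFormat = "$ptr attr-dict `:` qualified(type($ptr)) `->` type($res)";
```

Either form should parse the fully spelled `!llvm.ptr<1>` in the input; the difference shows up in the printed output, which is what the `CHECK` lines in the Dialect test match against.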
Co-authored-by: Krzysztof Drewniak <[email protected]>
Force-pushed: b136481 to 4a15ecc
LLVM Buildbot has detected a new failure on a builder. Full details are available at: https://lab.llvm.org/buildbot/#/builders/129/builds/32805